import time
import torch
import torch.optim as optim
# import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter
from sparsemax import Sparsemax

import numpy as np
from collections import defaultdict
from collections import OrderedDict
import os
from os.path import join as pjoin

from data.utils import MotionNormalizerTorch, face_joint_indx, fid_l, fid_r
from data.quaternion import *
from utils.utils import print_current_loss
from eval import evaluation_during_training
from models_interhuman_selfattn.mask_transformer.tools import *
from timm.utils import ApexScaler, NativeScaler
from einops import rearrange, repeat
from spikingjelly.clock_driven import functional

def def_value():
    """Return the starting total (0.0) for a log entry that has not been seen yet.

    Kept as a named module-level function so it can serve as the
    ``defaultdict`` factory for the running training logs.
    """
    initial_total = 0.0
    return initial_total

class MaskTransformerTrainer:
    def __init__(self, args, t2m_transformer, vq_model):
        """Store the models and options and build the loss helpers.

        Args:
            args: option namespace; this constructor reads ``accumulation_steps``,
                ``device``, ``is_train`` and (when training) ``log_dir``.
            t2m_transformer: the transformer that is going to be trained.
            vq_model: the VQ model; it is switched to eval mode here.
        """
        # Options and models.
        self.opt = args
        self.t2m_transformer = t2m_transformer
        self.vq_model = vq_model
        self.accumulation_steps = args.accumulation_steps
        self.device = args.device

        # The VQ model is put in eval mode right away.
        self.vq_model.eval()

        # Helpers used by the interaction losses.
        self.normalizer = MotionNormalizerTorch(self.device)
        # Element-wise (unreduced) loss so callers can apply their own masks.
        self.InteractionLoss = torch.nn.SmoothL1Loss(reduction='none')
        # NOTE: despite the attribute name this is a Sparsemax, not a softmax.
        self.softmax = Sparsemax(dim=-1)

        # A TensorBoard writer is only created for training runs.
        if args.is_train:
            self.logger = SummaryWriter(args.log_dir)


    def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
        """Apply the linear warm-up learning rate to every optimizer param group.

        Args:
            nb_iter: current iteration counter.
            warm_up_iter: number of warm-up iterations.
            lr: the target learning rate reached at the end of warm-up.

        Returns:
            The learning rate that was written: ``lr * (nb_iter + 1) / (warm_up_iter + 1)``.
        """
        warm_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
        for group in self.opt_t2m_transformer.param_groups:
            group["lr"] = warm_lr
        return warm_lr
    

    def calc_dm_loss(self, motion1_joints, motion2_joints, pred_motion1_joints, pred_motion2_joints, thresh_pred=1, thresh_tgt=0.1):
        pred_distance_matrix = torch.cdist(pred_motion1_joints.contiguous(), pred_motion2_joints)
        tgt_distance_matrix = torch.cdist(motion1_joints.contiguous(), motion2_joints)

        pred_distance_matrix = pred_distance_matrix.reshape(pred_distance_matrix.shape[0], -1) # T, njoints=22, 22 -> T, 484
        tgt_distance_matrix = tgt_distance_matrix.reshape(tgt_distance_matrix.shape[0], -1)

        dm_mask = (pred_distance_matrix < thresh_pred).float()
        dm_tgt_mask = (tgt_distance_matrix < thresh_tgt).float()
        
        dm_loss = (self.InteractionLoss(pred_distance_matrix, tgt_distance_matrix) * dm_mask).sum()/ (dm_mask.sum() + 1.e-7)
        dm_tgt_loss = (self.InteractionLoss(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix)) * dm_tgt_mask).sum()/ (dm_tgt_mask.sum() + 1.e-7)
        
        return dm_loss + dm_tgt_loss
    
    def calc_ro_loss(self, motion1_joints, motion2_joints, pred_motion1_joints, pred_motion2_joints):
        """Loss on the relative facing rotation between the two persons.

        For both the target and the predicted joints a facing direction is built
        per person from the hip axis (right hip minus left hip) crossed with the
        +Y axis; ``qbetween`` then yields the rotation from person 1's facing
        direction to person 2's. Components 0 and 2 of that result are compared
        between prediction and target.

        Args:
            motion1_joints, motion2_joints: target joints of person 1 / person 2.
            pred_motion1_joints, pred_motion2_joints: predicted joints, same layout.
                (Joints are indexed on the second-to-last dim and coordinates on
                the last one.)

        Returns:
            Scalar tensor: mean of the element-wise loss.
        """
        # Put the two persons side by side on a new dim 1.
        gt_pair = torch.stack([motion1_joints, motion2_joints], dim=1)
        pred_pair = torch.stack([pred_motion1_joints, pred_motion2_joints], dim=1)

        # Only the two hip indices are used; the other two entries are unpacked
        # to keep the 4-element contract of ``face_joint_indx``.
        r_hip, l_hip, sdr_r, sdr_l = face_joint_indx

        def hip_axis(joints):
            # Unit vector from the left hip to the right hip.
            axis = joints[..., r_hip, :] - joints[..., l_hip, :]
            return axis / axis.norm(dim=-1, keepdim=True)

        def facing(up, side):
            # Unit vector perpendicular to both ``up`` and ``side``.
            direction = torch.cross(up, side, dim=-1)
            return direction / direction.norm(dim=-1, keepdim=True)

        pred_side = hip_axis(pred_pair)
        gt_side = hip_axis(gt_pair)

        # +Y axis, shaped like the predicted hip axes.
        up_axis = torch.zeros_like(pred_side)
        up_axis[..., 1] = 1

        pred_facing = facing(up_axis, pred_side)
        gt_facing = facing(up_axis, gt_side)

        # Rotation taking person 1's facing direction onto person 2's.
        pred_rel = qbetween(pred_facing[..., 0, :], pred_facing[..., 1, :])
        gt_rel = qbetween(gt_facing[..., 0, :], gt_facing[..., 1, :])

        # Presumably components 0 and 2 are the non-trivial quaternion entries
        # for a rotation about Y -- confirm against ``qbetween``'s layout.
        picked = [0, 2]
        return self.InteractionLoss(pred_rel[..., picked], gt_rel[..., picked]).mean()

    def calc_interaction_loss(self, motion1, motion2, logits, id_lens):
        """Decode token logits back to motion and score them against the ground truth.

        The logits are turned into probabilities with ``self.softmax`` and split
        in two halves along dim 1 (first half -> person 1, second half -> person 2).
        Each sample is decoded on its own with the VQ decoder's soft lookup,
        de-normalized, and compared with the de-normalized ground truth.

        Args:
            motion1, motion2: normalized ground-truth motions of both persons,
                indexed ``[batch, frame, feature]``.
            logits: transformer logits; last dim is expected to be
                ``self.opt.num_tokens`` wide.
            id_lens: per-sample token lengths; frame lengths are taken as
                ``id_lens * 4``.

        Returns:
            ``(dm_loss, ro_loss, j_loss)`` averaged over the batch: the
            distance-matrix loss, the relative-orientation loss, and the
            reconstruction loss on the normalized features.
        """
        nbp = 5  # hard-coded group count used to unflatten the token axis (presumably body parts -- confirm)
        nt = self.opt.num_tokens
        n_joints = self.opt.joints_num
        m_lens = id_lens * 4

        # De-normalize the ground truth for both persons in one call.
        gt_pair = self.normalizer.backward(torch.stack([motion1, motion2], dim=-2))
        motion1_denorm, motion2_denorm = gt_pair.chunk(2, dim=-2)

        # Probabilities from logits, one half per person.
        probs = self.softmax(logits)
        probs1, probs2 = probs.chunk(2, dim=1)

        def soft_decode(person_probs, idx):
            # Keep only the valid tokens of sample ``idx`` in every group, then decode.
            sample = person_probs[[idx]].reshape(1, nbp, -1, nt)
            sample = sample[:, :, :id_lens[idx], :].reshape(1, -1, nt).unsqueeze(-2)
            return self.vq_model.forward_decoder(sample, soft_lookup=True)

        def joints_of(denorm, row, n_frames):
            # The leading ``joints_num * 3`` features are read as xyz joint positions.
            return denorm[row, :n_frames, :][..., :n_joints * 3].reshape(-1, n_joints, 3)

        dm_loss = 0
        ro_loss = 0
        j_loss = 0
        for i in range(len(id_lens)):
            n_frames = m_lens[i]

            pred_motion1 = soft_decode(probs1, i)
            pred_motion2 = soft_decode(probs2, i)

            # De-normalize the predictions the same way as the ground truth.
            pred_pair = self.normalizer.backward(torch.stack([pred_motion1, pred_motion2], dim=-2))
            pred_motion1_denorm, pred_motion2_denorm = pred_pair.chunk(2, dim=-2)

            motion1_joints = joints_of(motion1_denorm, i, n_frames)
            motion2_joints = joints_of(motion2_denorm, i, n_frames)
            # Predictions were decoded one sample at a time, hence row 0.
            pred_motion1_joints = joints_of(pred_motion1_denorm, 0, n_frames)
            pred_motion2_joints = joints_of(pred_motion2_denorm, 0, n_frames)

            dm_loss += self.calc_dm_loss(motion1_joints, motion2_joints, pred_motion1_joints, pred_motion2_joints)
            ro_loss += self.calc_ro_loss(motion1_joints, motion2_joints, pred_motion1_joints, pred_motion2_joints)

            # Reconstruction term is computed on the *normalized* features.
            recon1 = self.InteractionLoss(pred_motion1[0, :n_frames, :], motion1[i, :n_frames, :]).mean()
            recon2 = self.InteractionLoss(pred_motion2[0, :n_frames, :], motion2[i, :n_frames, :]).mean()
            j_loss += recon1 + recon2

        n_samples = len(m_lens)
        return dm_loss / n_samples, ro_loss / n_samples, j_loss / n_samples

    def get_vq_codes(self, motion1, motion2, m_lens):
        """Encode each sample at its true length and pad the codes to a fixed size.

        Every sample of both persons is truncated to its own length, encoded by
        the VQ model, reshaped into 5 groups, zero-padded per group, flattened
        again, and the two persons are concatenated along the token axis.

        Args:
            motion1, motion2: motions of both persons, indexed ``[batch, frame, ...]``.
            m_lens: tensor of per-sample frame lengths.

        Returns:
            int64 tensor of shape ``(B, motion1.shape[1] // 4 * 5 * 2, 1)`` on
            ``self.device``.
        """
        batch_size = motion1.shape[0]
        total_len = motion1.shape[1] // 4 * 5 * 2   # frames//4 tokens x 5 groups x 2 persons
        n_q = 1
        per_group = total_len // 10                 # padded token count of one group of one person
        all_codes = torch.zeros(batch_size, total_len, n_q, dtype=torch.int64).to(self.device)

        def pad_codes(codes):
            codes = codes.reshape(1, 5, -1, n_q)
            # NOTE(review): the padding value is 0. The original expression was
            # ``-1 * torch.zeros(...)``, which is also all zeros; if -1 was meant
            # as the pad id this needs a deliberate change -- confirm.
            filler = torch.zeros(1, 5, per_group - codes.shape[2], n_q,
                                 dtype=torch.int64, device=self.device)
            return torch.cat([codes, filler], dim=2).reshape(1, -1, n_q)

        for b in range(batch_size):
            n_frames = m_lens[b].item()
            codes1, _ = self.vq_model.encode(motion1[b].unsqueeze(0)[:, :n_frames])
            codes2, _ = self.vq_model.encode(motion2[b].unsqueeze(0)[:, :n_frames])

            codes1 = pad_codes(codes1)
            codes2 = pad_codes(codes2)

            all_codes[b] = torch.cat([codes1, codes2], dim=1)

        return all_codes
    
    def forward(self, batch_data):
        """Encode a batch of two-person motions to VQ tokens and run the transformer.

        Args:
            batch_data: one batch from the data loader. Its layout depends on
                ``self.opt.dataset_name``:
                  * ``"interhuman"``: ``(name, conds, motion1, motion2, m_lens)``
                  * ``"interx"``: a 7-tuple with ``conds`` at index 2, the joint
                    ``motions`` at index 4 and ``m_lens`` at index 5; ``motions``
                    is split into two persons in chunks of 6 along the last dim.

        Returns:
            ``(loss, acc)`` -- the first two values returned by the transformer.

        Raises:
            ValueError: if ``self.opt.dataset_name`` is not one of the supported
                names (previously this surfaced later as an ``UnboundLocalError``).
        """
        dataset_name = self.opt.dataset_name
        if dataset_name == "interhuman":
            _, conds, motion1, motion2, m_lens = batch_data
        elif dataset_name == "interx":
            _, _, conds, _, motions, m_lens, _ = batch_data
            motion1, motion2 = motions.split(6, dim=-1)
        else:
            raise ValueError(
                f"Unsupported dataset_name {dataset_name!r}; expected 'interhuman' or 'interx'"
            )

        motion1 = motion1.detach().float().to(self.device)
        motion2 = motion2.detach().float().to(self.device)
        m_lens = m_lens.detach().long().to(self.device)

        # Only the discrete code indices of the VQ model are used below, so
        # there is nothing to back-propagate through; skip building the autograd
        # graph for the encoder.
        with torch.no_grad():
            code_idx1, _ = self.vq_model.encode(motion1)
            code_idx2, _ = self.vq_model.encode(motion2)
        # Person 1 tokens followed by person 2 tokens along dim 1.
        code_idx = torch.cat([code_idx1, code_idx2], dim=1)

        # Conditions may be tensors or something else (e.g. raw text); only
        # tensors are moved to the device.
        conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds

        # Frame lengths -> token lengths (presumably the VQ temporal
        # down-sampling factor is 4 -- confirm against the VQ model).
        m_lens = m_lens // 4

        # Only the first entry of the last code dim is fed to the transformer.
        _loss, _acc, _, _, _ = self.t2m_transformer(code_idx[..., 0], conds, m_lens)
        return _loss, _acc
        
       

    def update_old(self, batch_data):
        loss, acc = self.forward(batch_data)

        self.opt_t2m_transformer.zero_grad()
        loss.backward()
        self.opt_t2m_transformer.step()
        self.scheduler.step()

        return loss.item(), acc
    def update(self, batch_data):
        """Run the forward pass only and hand back the un-detached loss.

        Unlike ``update_old`` this does not touch the optimizer or scheduler:
        backward and stepping are left to the caller.

        Args:
            batch_data: one batch, passed straight to ``forward``.

        Returns:
            ``(loss, acc)`` exactly as produced by ``forward``.
        """
        batch_loss, batch_acc = self.forward(batch_data)
        return batch_loss, batch_acc

    def save(self, file_name, ep, total_it):
        t2m_trans_state_dict = self.t2m_transformer.state_dict()
        clip_weights = [e for e in t2m_trans_state_dict.keys() if e.startswith('clip_')]
        for e in clip_weights:
            del t2m_trans_state_dict[e]
        state = {
            't2m_transformer': t2m_trans_state_dict,
            'opt_t2m_transformer': self.opt_t2m_transformer.state_dict(),
            # 'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None ,
            'ep': ep,
            'total_it': total_it,
        }
        torch.save(state, file_name)

    def resume(self, model_dir):
        checkpoint = torch.load(model_dir, map_location=self.device)
        missing_keys, unexpected_keys = self.t2m_transformer.load_state_dict(checkpoint['t2m_transformer'], strict=False)
        assert len(unexpected_keys) == 0
        assert all([k.startswith('clip_model.') for k in missing_keys])
        
        self.opt_t2m_transformer.load_state_dict(checkpoint['opt_t2m_transformer']) # Optimizer
        try:
            self.scheduler.load_state_dict({key: checkpoint['scheduler'][key] for key in ["last_epoch", "_step_count"]}) # Scheduler
        except:
            print('Resume wo optimizer')
        return checkpoint['ep'], checkpoint['total_it']

    def train(self, train_loader, val_loader, test_loader, eval_wrapper):
        """Train the transformer with gradient accumulation, validating every epoch.

        Per epoch this runs one pass over ``train_loader`` (optimizer step every
        ``self.accumulation_steps`` batches and at the last batch of the epoch),
        one pass over ``val_loader`` without gradients, and -- when
        ``self.opt.do_eval`` is set and the epoch is a multiple of
        ``self.opt.eval_every_e`` -- a call to ``evaluation_during_training``.

        Checkpoints written under ``self.opt.model_dir``: ``latest.tar``
        (periodically and at the end of every epoch), ``best_acc.tar``,
        ``finest.tar`` (lowest validation loss), ``best_fid.tar`` and
        ``best_top1.tar``.

        Side effects on ``self.opt``: ``milestones``, ``warm_up_iter``,
        ``log_every`` and ``save_latest`` are overwritten from the loader length,
        and ``eval_every_e`` is set to 10 once the epoch exceeds 200.

        Args:
            train_loader: sized iterable of training batches accepted by ``forward``.
            val_loader: sized iterable of validation batches accepted by ``forward``.
            test_loader: forwarded to ``evaluation_during_training``.
            eval_wrapper: forwarded to ``evaluation_during_training``.

        Raises:
            ValueError: if ``self.accumulation_steps`` is smaller than 1.
        """
        # Gradient accumulation settings. Validate before doing any work: a
        # value below 1 would otherwise fail in the loop with a modulo-by-zero.
        accumulation_steps = self.accumulation_steps
        if accumulation_steps < 1:
            raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

        self.t2m_transformer.to(self.device)
        self.vq_model.to(self.device)

        iters_per_epoch = len(train_loader)
        total_iters = self.opt.max_epoch * iters_per_epoch
        self.opt.milestones = [int(total_iters * 0.5), int(total_iters * 0.7), int(total_iters * 0.85)]
        self.opt.warm_up_iter = iters_per_epoch // 4
        # Both intervals are used as modulo divisors below, so they must never
        # be 0 (which happened for loaders shorter than 10 / 2 batches).
        self.opt.log_every = max(1, iters_per_epoch // 10)
        self.opt.save_latest = max(1, iters_per_epoch // 2)
        print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (iters_per_epoch, len(val_loader)))
        print(f'Milestones: {self.opt.milestones}')
        print('Warm Up Iterations: %04d, Log Every: %04d, Save Latest: %04d' % (self.opt.warm_up_iter, self.opt.log_every, self.opt.save_latest))

        self.opt_t2m_transformer = optim.AdamW(self.t2m_transformer.parameters(), betas=(0.9, 0.99), lr=self.opt.lr, weight_decay=1e-5)
        # NOTE(review): the milestones above are counted in training iterations,
        # but ``scheduler.step()`` below runs once per optimizer step, i.e. once
        # every ``accumulation_steps`` iterations. With accumulation_steps > 1
        # the decays therefore happen later than the 50/70/85% marks (or never)
        # -- confirm whether that is intended.
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_t2m_transformer, milestones=self.opt.milestones, gamma=self.opt.gamma)
        self.scaler = NativeScaler()

        epoch = 0
        it = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)
            # Round the iteration counter down to a multiple of the log interval.
            it = it // self.opt.log_every * self.opt.log_every
            print("Load model epoch:%d iterations:%d" % (epoch, it))

        start_time = time.time()
        logs = defaultdict(def_value, OrderedDict())

        max_acc = -np.inf
        min_loss = np.inf
        min_fid = np.inf
        max_top1 = -np.inf

        if self.opt.do_eval:
            eval_file = pjoin(self.opt.eval_dir, 'evaluation_training.log')

        while epoch < self.opt.max_epoch:
            epoch += 1
            self.t2m_transformer.train()
            self.vq_model.eval()

            if epoch > 200:
                self.opt.eval_every_e = 10

            for i, batch in enumerate(train_loader):
                # Reset the network state before every batch.
                functional.reset_net(self.t2m_transformer)
                it += 1
                if it < self.opt.warm_up_iter:
                    self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)

                # Zero gradients at the start of every accumulation window.
                if (i % accumulation_steps) == 0:
                    self.opt_t2m_transformer.zero_grad()

                loss, acc = self.update(batch_data=batch)
                # Scale so that the accumulated gradient is the window average.
                loss /= accumulation_steps
                # Backward pass and gradient accumulation (through the grad scaler).
                self.scaler._scaler.scale(loss).backward()

                # Step at the end of a window, or at the last batch of the epoch
                # so that a trailing partial window is not lost.
                if (i + 1) % accumulation_steps == 0 or (i + 1) == iters_per_epoch:
                    self.scaler._scaler.step(self.opt_t2m_transformer)
                    self.scaler._scaler.update()
                    self.scheduler.step()

                # Note: the logged loss is the value already divided by
                # ``accumulation_steps``.
                logs['loss'] += loss.item()
                logs['acc'] += acc
                logs['lr'] += self.opt_t2m_transformer.param_groups[0]['lr']

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        self.logger.add_scalar('Train/%s' % tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = defaultdict(def_value, OrderedDict())
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            print('Validation time:')
            self.vq_model.eval()
            self.t2m_transformer.eval()

            val_loss = []
            val_acc = []
            with torch.no_grad():
                for i, batch_data in enumerate(val_loader):
                    loss, acc = self.forward(batch_data)
                    val_loss.append(loss.item())
                    val_acc.append(acc)

            # Compute the epoch means once and reuse them below.
            mean_val_loss = np.mean(val_loss)
            mean_val_acc = np.mean(val_acc)

            print(f"Validation loss:{mean_val_loss:.3f}, accuracy:{mean_val_acc:.3f}")

            self.logger.add_scalar('Val/loss', mean_val_loss, epoch)
            self.logger.add_scalar('Val/acc', mean_val_acc, epoch)

            if mean_val_acc > max_acc:
                print(f"Improved accuracy from {max_acc:.02f} to {mean_val_acc}!!!")
                self.save(pjoin(self.opt.model_dir, 'best_acc.tar'), epoch, it)
                max_acc = mean_val_acc

            if mean_val_loss < min_loss:
                print(f"Improved Loss from {min_loss:.02f} to {mean_val_loss}!!!")
                self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
                min_loss = mean_val_loss

            if self.opt.do_eval:
                if epoch % self.opt.eval_every_e == 0:
                    self.vq_model.eval()
                    self.t2m_transformer.eval()
                    fid, mat, top1 = evaluation_during_training(self.opt, self.vq_model, test_loader, eval_wrapper, epoch, eval_file, trans=self.t2m_transformer)
                    self.logger.add_scalar('Test/FID', fid, epoch)
                    self.logger.add_scalar('Test/Matching', mat, epoch)
                    self.logger.add_scalar('Test/Top1', top1, epoch)
                    if fid < min_fid:
                        min_fid = fid
                        self.save(pjoin(self.opt.model_dir, 'best_fid.tar'), epoch, it)
                        print('Best FID Model So Far!~')
                    if top1 > max_top1:
                        max_top1 = top1
                        self.save(pjoin(self.opt.model_dir, 'best_top1.tar'), epoch, it)
                        print('Best Top1 Model So Far!~')

            print('\n')
